import logging
import sys
import numpy as np
import torch
import os
import random
import imageio
from tqdm import trange
import re


def str2bool(value):
    """Parse a yes/no style flag string into a bool.

    Accepted (case-insensitive) truthy strings: "yes", "true", "t", "1".
    Accepted falsy strings: "no", "false", "f", "0".
    An actual ``bool`` is returned unchanged, so the function is also safe to
    call on values that were never strings (e.g. programmatic defaults).

    Raises:
        ValueError: if ``value`` is a string not in either accepted set.
    """
    # Already a bool: nothing to parse (the original crashed on ``.lower()``).
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "1"):
        return True
    if lowered in ("no", "false", "f", "0"):
        return False
    raise ValueError("Boolean value expected.")


# def set_seed_everywhere(seed):
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)
#     np.random.seed(seed)
#     random.seed(seed)


def set_seed_everywhere(seed: int, using_cuda: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs with ``seed``.

    When ``using_cuda`` is true, CUDA generators are seeded as well and cuDNN is
    switched to deterministic mode (autotuning disabled).

    ``PYTHONHASHSEED`` is exported last. Note that the running interpreter's
    hash seed is fixed at startup, so this only influences Python child
    processes launched afterwards.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if using_cuda:
        # Current device first, then every visible device.
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
        cudnn = torch.backends.cudnn
        # Disable algorithm autotuning and force deterministic kernels.
        cudnn.benchmark = False
        cudnn.deterministic = True

    os.environ["PYTHONHASHSEED"] = str(seed)


# def get_logger(logger_name: str, log_file_path: str) -> logging.Logger:
#     logger = logging.getLogger(logger_name)
#     logger.setLevel(logging.DEBUG)

#     # console handler
#     console_handler = logging.StreamHandler()
#     console_handler.setLevel(logging.DEBUG)

#     # file handler
#     file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
#     file_handler.setLevel(logging.DEBUG)

#     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

#     console_handler.setFormatter(formatter)
#     file_handler.setFormatter(formatter)

#     logger.addHandler(console_handler)
#     logger.addHandler(file_handler)
#     return logger


class FileLogger:
    """Tee text into both a UTF-8 file and the stdout captured at construction.

    It exposes ``write``/``flush``/``close``. Presumably it is meant to be
    installed as ``sys.stdout`` so printed output is mirrored to a file; verify
    against callers. It can also be used as a context manager, which closes the
    file on exit.
    """

    def __init__(self, filename, mode="w"):
        self.file = open(filename, mode, encoding="utf-8")
        # Captured now, so it keeps pointing at the original stream even if this
        # object later replaces sys.stdout.
        self.stdout = sys.stdout
        # Not used by this class; kept as a public attribute for callers.
        self.stderr = sys.stderr

    def write(self, data):
        """Write ``data`` to the file and the captured stdout.

        Returns the number of characters written, as file-like ``write`` does.
        """
        self.file.write(data)
        self.stdout.write(data)
        return len(data)

    def flush(self):
        """Flush both sinks; the file is skipped once it has been closed.

        Without this guard, a closed logger still installed as ``sys.stdout``
        raises ``ValueError`` when flushed (e.g. at interpreter shutdown).
        """
        if not self.file.closed:
            self.file.flush()
        self.stdout.flush()

    def close(self):
        """Close the log file; the captured stdout is left open."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the ``with`` body.
        return False

def extract_code(response: str) -> str:
    """Return the first code-like snippet found in ``response``.

    Patterns are tried in priority order, and the first pattern that matches
    anywhere wins. The captured text is returned verbatim, including any
    leading/trailing newlines inside the fence. The order is:

    1. a fenced block tagged ``python``/``python3``/``py`` (tag case-insensitive),
    2. any fenced block,
    3. triple-, double-double-, then single-double-quoted text.

    If nothing matches, the sentinel string ``"No Python Code!"`` is returned.
    """
    patterns = [
        # Accept common tag spellings so the tag itself is not swallowed into
        # the code by the generic fence pattern below (e.g. "py\n...").
        r"```(?:python3?|py)\b(.*?)```",
        r"```(.*?)```",
        r'"""(.*?)"""',
        r'""(.*?)""',
        r'"(.*?)"',
    ]
    code_content = "No Python Code!"
    for pattern in patterns:
        # IGNORECASE only affects the tagged-fence pattern; the others contain
        # no letters. DOTALL lets ``.`` span the newlines of multi-line code.
        match = re.search(pattern, response, re.DOTALL | re.IGNORECASE)
        if match:
            code_content = match.group(1)
            break
    # NOTE(review): the disabled block below references ``env_name``, which is
    # not a parameter of this function; it would need to be passed in if revived.
    # if "antmaze" in env_name:
    #     # goal is the tuple of goal position (x, y)
    #     code_content = code_content.replace(") -> float:", ", goal) -> float:")

    #     code_content = code_content.replace("obs[29]", "goal[0]")
    #     code_content = code_content.replace("obs[30]", "goal[1]")

    #     code_content = code_content.replace("next_obs[29]", "goal[0]")
    #     code_content = code_content.replace("next_obs[30]", "goal[1]")
    return code_content


def standardize_rewards(array):
    """Return a z-scored copy of ``array`` (zero mean, unit population std).

    The input is always copied into a new NumPy array. If the input is empty,
    or all of its elements are identical, that copy is returned unchanged
    (not centred), because there is no spread to scale by.
    """
    array = np.array(array)
    # Empty input: mean/std would be NaN (with RuntimeWarnings); nothing to do.
    if array.size == 0:
        return array
    # Test constancy exactly. ``std == 0`` alone is unreliable: for
    # [0.1, 0.1, 0.1] the computed mean rounds to a neighbouring float, std
    # comes out ~1e-17, and the division would yield roughly [-1, -1, -1].
    if array.max() == array.min():
        return array
    mean, std = array.mean(), array.std()
    # std can still underflow to exactly 0 when values differ only by
    # subnormal amounts; avoid dividing by zero in that case.
    if std == 0:
        return array
    return (array - mean) / std


def min_max_normalize(array):
    """Linearly rescale ``array`` so that its minimum maps to 0 and its maximum to 1.

    Accepts anything ``np.asarray`` accepts. Empty or constant input returns
    ``np.zeros_like`` of the input, keeping the input dtype. Floating inputs
    keep their dtype. Integer and bool inputs are computed in float64.
    """
    arr = np.asarray(array)
    # np.min/np.max raise on zero-size arrays; an empty input normalizes to empty.
    if arr.size == 0:
        return np.zeros_like(arr)
    min_val = arr.min()
    max_val = arr.max()
    if max_val == min_val:
        return np.zeros_like(arr)
    if not np.issubdtype(arr.dtype, np.inexact):
        # In narrow integer dtypes, (x - min) and (max - min) can wrap around
        # (int8 [-128, 0, 127] would give 128 instead of ~0.5), and bool
        # subtraction is rejected by NumPy. Compute these in float64 instead.
        arr = arr.astype(np.float64)
        min_val = float(min_val)
        max_val = float(max_val)
    return (arr - min_val) / (max_val - min_val)

def get_antmaze_fix_goals(env_name, batch_size):
    """Return a ``(batch_size, 2)`` float array that repeats one fixed (x, y) goal.

    The maze size is picked by substring match on ``env_name``. The checks run
    in the order large, medium, umaze, and the first hit wins.

    Raises:
        NotImplementedError: if ``env_name`` contains none of the known maze names.
    """
    # (substring of env_name, fixed goal position), checked in this order.
    fixed_goals = (
        ("antmaze-large", [32.0, 24.0]),
        ("antmaze-medium", [20.0, 20.0]),
        ("antmaze-umaze", [0.0, 8.0]),
    )
    # Per-dataset goals from an earlier version, kept for reference:
    #   antmaze-large-diverse-v0:  [32.67495803, 24.75409971]
    #   antmaze-large-play-v0:     [32.6390761, 24.60561743]
    #   antmaze-medium-diverse-v0: [20.66263452, 20.79289385]
    #   antmaze-medium-play-v0:    [20.7439014, 20.71133478]
    #   antmaze-umaze-diverse-v0:  [0.53155234, 8.68389583]
    #   antmaze-umaze-v0:          [0.55228592, 8.70199509]
    for maze_key, goal_xy in fixed_goals:
        if maze_key in env_name:
            # The 2-element goal broadcasts across every row of the batch.
            return np.full((batch_size, 2), goal_xy)
    # Disabled idea from an earlier version:
    #   terminals_float = 1 - replay_buffer.not_done
    #   goals = np.concatenate([goals, terminals_float], axis=1)
    raise NotImplementedError("="*10 + f"Environment {env_name} does not support goals." + "="*10)